{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2b408fea",
   "metadata": {},
   "source": [
    "## This notebook processes the numerical simulation results and generates the figures for the numerical section of the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2587ed0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.io import loadmat\n",
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "# These are used for plotting later on\n",
    "from cycler import cycler\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Edit this line to point to the location of the data folder on your system.\n",
    "# NOTE: './data' is a relative path, resolved against the kernel's current working directory.\n",
    "data_folder_prefix = './data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53b4ed90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground truth for the R-sweep test set; every framework below is scored\n",
    "# object by object against these same images.\n",
    "prefix = data_folder_prefix + '/Simulated/R_Sweep/'\n",
    "# keep only the phase of the complex-valued objects\n",
    "test_images = np.angle(np.load(prefix + 'test_images.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5bd7a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first nine ground-truth phase maps on a 3x3 grid, all drawn on the\n",
    "# colour scale spanned by the whole test set.\n",
    "plt.rcParams['figure.figsize'] = [20, 20]\n",
    "vmin, vmax = np.min(test_images), np.max(test_images)\n",
    "\n",
    "fig = plt.figure()\n",
    "fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "for panel in range(9):\n",
    "    ax = fig.add_subplot(3, 3, panel + 1)\n",
    "    ax.xaxis.set_visible(False)\n",
    "    ax.yaxis.set_visible(False)\n",
    "    ax.imshow(test_images[panel], cmap='viridis', vmin=vmin, vmax=vmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e144060",
   "metadata": {},
   "source": [
    "## Each cell below computes the MS-SSIM values for different R values using a different framework"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bac37616",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# load the MS-SSIM loss function\n",
    "ssim = tf.image.ssim_multiscale\n",
    "\n",
    "# Generative, not-pretrained network results. For every R, sweep the\n",
    "# hyperparameter alpha and keep the MS-SSIM mean/std of the best alpha.\n",
    "ssim_mean1 = []\n",
    "ssim_std1 = []\n",
    "\n",
    "alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]\n",
    "# each (alpha, R) result is expected in exactly one of these dated files\n",
    "date_tags = ['2021-09-27', '2021-09-28', '2021-09-29', '2021-09-30']\n",
    "\n",
    "for R in [0.25, 0.5, 1.0, 2.0]:\n",
    "    # reconstructions for every alpha, kept so the best one can be plotted\n",
    "    pic = []\n",
    "    # per-alpha MS-SSIM mean and std over the test objects at this R\n",
    "    ssim_temp = []\n",
    "    std_temp = []\n",
    "\n",
    "    for alpha in alphas:\n",
    "        suffix = '-not-pretrained-alpha-%0.2f-R-%0.2f.mat' % (alpha, R)\n",
    "        found = [tag + suffix for tag in date_tags if os.path.isfile(tag + suffix)]\n",
    "        # stop the alpha sweep at the first missing or ambiguous result\n",
    "        if not found:\n",
    "            print('missing files for R-%0.2f-alpha-%0.2f' % (R, alpha))\n",
    "            break\n",
    "        if len(found) > 1:\n",
    "            print('multiple files for R-%0.2f-alpha-%0.2f: %s' % (R, alpha, found))\n",
    "            break\n",
    "        rec_test_output = loadmat(found[0])['rec_test_output']\n",
    "        pic.append(rec_test_output)\n",
    "\n",
    "        # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "        loss_ssim_list = []\n",
    "        for gen, true in zip(rec_test_output, test_images):\n",
    "            true = tf.expand_dims(true, -1)\n",
    "            gen = tf.expand_dims(gen, -1)\n",
    "            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "            loss_ssim_list.append(ssim_data)\n",
    "\n",
    "        ssim_temp.append(np.mean(loss_ssim_list))\n",
    "        std_temp.append(np.std(loss_ssim_list))\n",
    "\n",
    "    # np.argmax would otherwise fail with an opaque error on an empty list\n",
    "    if not ssim_temp:\n",
    "        raise FileNotFoundError('no generative not-pretrained results found for R = %0.2f' % R)\n",
    "\n",
    "    print('the current R is:', R)\n",
    "    idx = np.argmax(ssim_temp)\n",
    "    print('the best alpha is:', alphas[idx])\n",
    "    # only keep the best alpha for the given R\n",
    "    ssim_mean1.append(ssim_temp[idx])\n",
    "    ssim_std1.append(std_temp[idx])\n",
    "\n",
    "    # plot the best reconstructions for the given R\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(9):\n",
    "        ax = fig.add_subplot(3, 3, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(pic[idx][i], cmap='viridis')\n",
    "        plt.clim(np.min(test_images), np.max(test_images))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f834322",
   "metadata": {},
   "source": [
    "# generative pretrained results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "297c8c27",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Generative, pretrained network results: the same alpha sweep as the\n",
    "# not-pretrained cell, but reading the '-pretrained-' result files.\n",
    "ssim_mean2 = []\n",
    "ssim_std2 = []\n",
    "\n",
    "alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]\n",
    "# each (alpha, R) result is expected in exactly one of these dated files\n",
    "date_tags = ['2021-09-27', '2021-09-28', '2021-09-29', '2021-09-30']\n",
    "\n",
    "for R in [0.25, 0.5, 1.0, 2.0]:\n",
    "    # reconstructions for every alpha, kept so the best one can be plotted\n",
    "    pic = []\n",
    "    # per-alpha MS-SSIM mean and std over the test objects at this R\n",
    "    ssim_temp = []\n",
    "    std_temp = []\n",
    "\n",
    "    for alpha in alphas:\n",
    "        # load the generative *pretrained* network results\n",
    "        suffix = '-pretrained-alpha-%0.2f-R-%0.2f.mat' % (alpha, R)\n",
    "        found = [tag + suffix for tag in date_tags if os.path.isfile(tag + suffix)]\n",
    "        # stop the alpha sweep at the first missing or ambiguous result\n",
    "        if not found:\n",
    "            print('missing files for R-%0.2f-alpha-%0.2f' % (R, alpha))\n",
    "            break\n",
    "        if len(found) > 1:\n",
    "            print('multiple files for R-%0.2f-alpha-%0.2f: %s' % (R, alpha, found))\n",
    "            break\n",
    "        rec_test_output = loadmat(found[0])['rec_test_output']\n",
    "        pic.append(rec_test_output)\n",
    "\n",
    "        # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "        loss_ssim_list = []\n",
    "        for gen, true in zip(rec_test_output, test_images):\n",
    "            true = tf.expand_dims(true, -1)\n",
    "            gen = tf.expand_dims(gen, -1)\n",
    "            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "            loss_ssim_list.append(ssim_data)\n",
    "\n",
    "        ssim_temp.append(np.mean(loss_ssim_list))\n",
    "        std_temp.append(np.std(loss_ssim_list))\n",
    "\n",
    "    # np.argmax would otherwise fail with an opaque error on an empty list\n",
    "    if not ssim_temp:\n",
    "        raise FileNotFoundError('no generative pretrained results found for R = %0.2f' % R)\n",
    "\n",
    "    print('the current R is:', R)\n",
    "    idx = np.argmax(ssim_temp)\n",
    "    print('the best alpha is:', alphas[idx])\n",
    "    # only keep the best alpha for the given R\n",
    "    ssim_mean2.append(ssim_temp[idx])\n",
    "    ssim_std2.append(std_temp[idx])\n",
    "\n",
    "    # plot the best reconstructions for the given R\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(9):\n",
    "        ax = fig.add_subplot(3, 3, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(pic[idx][i], cmap='viridis')\n",
    "        plt.clim(np.min(test_images), np.max(test_images))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90f5d1cd",
   "metadata": {},
   "source": [
    "# not generative not pretrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d688e8f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Non-generative, not-pretrained results: one run per R, so no alpha sweep.\n",
    "ssim_mean3 = []\n",
    "ssim_std3 = []\n",
    "\n",
    "photon_level = 1e4\n",
    "\n",
    "for R in [0.25, 0.5, 1, 2]:\n",
    "    # build the file name from photon_level (str(1e4) == '10000.0') so the two cannot\n",
    "    # drift apart; scipy's loadmat appends the '.mat' extension by default\n",
    "    matfile = loadmat('2021-09-28not-pretrained-R-' + str(R) + '-peak-' + str(photon_level))\n",
    "    rec_test_output = matfile['rec_test_output']\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(9):\n",
    "        ax = fig.add_subplot(3, 3, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[i], cmap='viridis')\n",
    "        plt.clim(np.min(test_images), np.max(test_images))\n",
    "\n",
    "    # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "    loss_ssim_list = []\n",
    "    for gen, true in zip(rec_test_output, test_images):\n",
    "        true = tf.expand_dims(true, -1)\n",
    "        gen = tf.expand_dims(gen, -1)\n",
    "        ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "        loss_ssim_list.append(ssim_data)\n",
    "\n",
    "    print('R =', R)\n",
    "    ssim_mean3.append(np.mean(loss_ssim_list))\n",
    "    ssim_std3.append(np.std(loss_ssim_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54ce6a65",
   "metadata": {},
   "source": [
    "# not generative pretrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a455ced",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Non-generative, pretrained results: one run per R, so no alpha sweep.\n",
    "ssim_mean4 = []\n",
    "ssim_std4 = []\n",
    "\n",
    "photon_level = 1e4\n",
    "\n",
    "for R in [0.25, 0.5, 1, 2]:\n",
    "    # build the file name from photon_level (str(1e4) == '10000.0') so the two cannot\n",
    "    # drift apart; scipy's loadmat appends the '.mat' extension by default\n",
    "    matfile = loadmat('2021-09-27pretrained-R-' + str(R) + '-peak-' + str(photon_level))\n",
    "    rec_test_output = matfile['rec_test_output']\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(9):\n",
    "        ax = fig.add_subplot(3, 3, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[i], cmap='viridis')\n",
    "        plt.clim(np.min(test_images), np.max(test_images))\n",
    "\n",
    "    # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "    loss_ssim_list = []\n",
    "    for gen, true in zip(rec_test_output, test_images):\n",
    "        true = tf.expand_dims(true, -1)\n",
    "        gen = tf.expand_dims(gen, -1)\n",
    "        ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "        loss_ssim_list.append(ssim_data)\n",
    "\n",
    "    print('R =', R)\n",
    "    ssim_mean4.append(np.mean(loss_ssim_list))\n",
    "    ssim_std4.append(np.std(loss_ssim_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebab34cd",
   "metadata": {},
   "source": [
    "# End-to-End"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "560a1bdd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# End-to-End results: one saved output array per R.\n",
    "ssim_mean5 = []\n",
    "ssim_std5 = []\n",
    "\n",
    "photon_level = 1e4\n",
    "\n",
    "for R in [0.25, 0.5, 1, 2]:\n",
    "    # build the file name from photon_level (str(1e4) == '10000.0') so the two cannot drift apart\n",
    "    rec_test_output = np.load('End-to-End-test-output-R-' + str(R) + '-photon-' + str(photon_level) + '.npy')\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(9):\n",
    "        ax = fig.add_subplot(3, 3, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[i], cmap='viridis')\n",
    "        plt.clim(np.min(test_images), np.max(test_images))\n",
    "\n",
    "    # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "    loss_ssim_list = []\n",
    "    for gen, true in zip(rec_test_output, test_images):\n",
    "        true = tf.expand_dims(true, -1)\n",
    "        gen = tf.expand_dims(gen, -1)\n",
    "        ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "        loss_ssim_list.append(ssim_data)\n",
    "\n",
    "    print('R =', R)\n",
    "    ssim_mean5.append(np.mean(loss_ssim_list))\n",
    "    ssim_std5.append(np.std(loss_ssim_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1bd28f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_global_phase(im1, im2):\n",
    "    \"\"\"Correct the global phase offset of an iterative reconstruction.\n",
    "\n",
    "    Shifts ``im1`` by a constant so that its mean equals the mean of ``im2``.\n",
    "\n",
    "    Args:\n",
    "        im1: reconstructed phase map to correct (array-like, any shape).\n",
    "        im2: reference phase map, e.g. the ground truth.\n",
    "\n",
    "    Returns:\n",
    "        ``im1 + (mean(im2) - mean(im1))`` as a new array with the shape of\n",
    "        ``im1``; the inputs are not modified.\n",
    "    \"\"\"\n",
    "    im1 = np.asarray(im1)\n",
    "    # keep im1's own shape instead of forcing a fixed 256x256 layout\n",
    "    offset = np.mean(im2) - np.mean(im1)\n",
    "    return im1 + offset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31b23219",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Preview the single-iteration (lr = 1) iterative reconstructions, shifted per\n",
    "# object so that their mean phase matches the ground truth.\n",
    "photon_level = 1e4\n",
    "lr = 1\n",
    "\n",
    "for R in [0.25, 0.5, 1, 2]:\n",
    "    for iters in [1]:\n",
    "        iter_input = np.load(data_folder_prefix + '/Simulated/R_Sweep/test-reconstruction-R-%0.2f-phperpix-%d-iters-%d-lr-%0.2f.npy'\n",
    "                             % (R, photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "        # per-object constant offset between ground-truth and reconstruction means\n",
    "        mean1 = np.mean(iter_input, axis=(1, 2))\n",
    "        mean2 = np.mean(test_images, axis=(1, 2))\n",
    "        b1 = mean2 - mean1\n",
    "        print(b1)\n",
    "\n",
    "        plt.rcParams['figure.figsize'] = [20, 20]\n",
    "        fig = plt.figure()\n",
    "        fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "        for k in range(9):\n",
    "            ax = fig.add_subplot(3, 3, k + 1)\n",
    "            ax.axes.xaxis.set_visible(False)\n",
    "            ax.axes.yaxis.set_visible(False)\n",
    "            plt.imshow(iter_input[k] + b1[k], cmap='viridis')\n",
    "            plt.clim(np.min(test_images), np.max(test_images))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "963bb3c7",
   "metadata": {},
   "source": [
    "# Iterative reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f20b8651",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Iterative reconstruction after 100 iterations (lr = 0.5): plot the first nine\n",
    "# objects and collect the MS-SSIM mean/std for each R.\n",
    "iterative_std = []\n",
    "iterative_ssim = []\n",
    "photon_level = 1e4\n",
    "lr = 0.5\n",
    "\n",
    "for R in [0.25, 0.5, 1, 2]:\n",
    "    for iters in [100]:\n",
    "        iter_input = np.load(data_folder_prefix + '/Simulated/R_Sweep/test-reconstruction-R-%0.2f-phperpix-%d-iters-%d-lr-%0.2f.npy'\n",
    "                             % (R, photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "        plt.rcParams['figure.figsize'] = [20, 20]\n",
    "        fig = plt.figure()\n",
    "        fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "        for k in range(9):\n",
    "            ax = fig.add_subplot(3, 3, k + 1)\n",
    "            ax.axes.xaxis.set_visible(False)\n",
    "            ax.axes.yaxis.set_visible(False)\n",
    "            plt.imshow(iter_input[k], cmap='viridis')\n",
    "            plt.clim(np.min(test_images), np.max(test_images))\n",
    "\n",
    "        iterloss_ssim_list = []\n",
    "        n_pairs = min(len(iter_input), len(test_images))\n",
    "        for gen, true in tqdm(zip(iter_input, test_images), total=n_pairs):\n",
    "            # remove the global phase offset, then clip to the ground-truth range\n",
    "            gen = find_global_phase(gen, true)\n",
    "            gen = np.clip(gen, np.min(true), np.max(true))\n",
    "\n",
    "            true = tf.expand_dims(true, -1)\n",
    "            gen = tf.expand_dims(gen, -1)\n",
    "\n",
    "            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "            iterloss_ssim_list.append(ssim_data)\n",
    "\n",
    "        iterative_ssim.append(np.mean(iterloss_ssim_list))\n",
    "        iterative_std.append(np.std(iterloss_ssim_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a6726e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-object MS-SSIM standard deviation of the iterative method, one value per R\n",
    "iterative_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b225c9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean MS-SSIM per R (R = 0.25, 0.5, 1, 2) for each learned framework.\n",
    "print('generative, not pretrained:', ssim_mean1)\n",
    "print('generative, pretrained:', ssim_mean2)\n",
    "print('non-generative, not pretrained:', ssim_mean3)\n",
    "print('non-generative, pretrained:', ssim_mean4)\n",
    "print('end-to-end:', ssim_mean5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0cc1f83",
   "metadata": {},
   "source": [
    "# Plotting the R sweep results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5196efe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart of the R-sweep MS-SSIM (mean +/- std over the test objects).\n",
    "# cycler and mpl are imported in the first code cell.\n",
    "plt.style.use('ggplot')\n",
    "# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in matplotlib 3.6 and the\n",
    "# old name was removed in 3.8; use whichever name this installation provides.\n",
    "whitegrid_style = 'seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid'\n",
    "plt.style.use(whitegrid_style)\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'chocolate', 'olive'])\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "labels = ['R = 0.25', 'R = 0.5', 'R = 1', 'R = 2']\n",
    "\n",
    "x = np.arange(len(labels))  # the label locations\n",
    "width = 0.12  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "rects1 = ax.bar(x - width * 3/2, ssim_mean4, width, yerr=ssim_std4, label='Non Generative', ecolor='black', capsize=5)\n",
    "rects2 = ax.bar(x - width * 1/2, ssim_mean1, width, yerr=ssim_std1, label='Generative', ecolor='black', capsize=5)\n",
    "#rects_pretrain = ax.bar(x + width * 0, ssim_mean2, width, yerr=ssim_std2, label='Generative-Pretrain', ecolor='black', capsize=5)\n",
    "rects3 = ax.bar(x + width * 1/2, iterative_ssim, width, yerr=iterative_std, label='Iterative', ecolor='black', capsize=5)\n",
    "rects4 = ax.bar(x + width * 3/2, ssim_mean5, width, yerr=ssim_std5, label='End-to-End', ecolor='black', capsize=5)\n",
    "\n",
    "# Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "ax.set_ylabel('MS-SSIM', fontsize=35)\n",
    "#ax.set_title('R=0.25, 0.5, 1, 2' + ' with 1e3 photon and ImageNet4K', fontsize=25)\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels, fontsize=25)\n",
    "# tick_params runs after set_xticklabels, so the tick labels end up at size 35\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "ax.legend(fontsize=35, loc='upper right')\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "\n",
    "\n",
    "def autolabel(rects):\n",
    "    \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
    "    for rect in rects:\n",
    "        height = rect.get_height()\n",
    "        ax.annotate('{}'.format(height),\n",
    "                    xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "                    xytext=(0, 3),  # 3 points vertical offset\n",
    "                    textcoords=\"offset points\",\n",
    "                    ha='center', va='bottom')\n",
    "\n",
    "#autolabel(rects1)\n",
    "#autolabel(rects2)\n",
    "fig.tight_layout()\n",
    "plt.ylim(0., 1.08)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5138a3ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MS-SSIM of the iterative reconstruction (lr = 0.5) for each iteration count.\n",
    "# iter_ssim is filled R-major with 8 entries per R; the runtime plot below\n",
    "# slices it in blocks of 8, so keep the iteration list at 8 entries.\n",
    "lr = 0.5\n",
    "\n",
    "iter_ssim = []\n",
    "\n",
    "for R in [0.25, 0.5, 1, 2]:\n",
    "    for photon_level in [1e4]:\n",
    "        for iters in [1, 5, 10, 20, 30, 40, 50, 60]:\n",
    "            iter_input = np.load(data_folder_prefix + '/Simulated/R_Sweep/test-reconstruction-R-%0.2f-phperpix-%d-iters-%d-lr-%0.2f.npy'\n",
    "                                 % (R, photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "            iterloss_ssim_list = []\n",
    "            for gen, true in zip(iter_input, test_images):\n",
    "                # remove the global phase offset, then clip to the ground-truth range\n",
    "                gen = find_global_phase(gen, true)\n",
    "                gen = np.clip(gen, np.min(true), np.max(true))\n",
    "\n",
    "                true = tf.expand_dims(true, -1)\n",
    "                gen = tf.expand_dims(gen, -1)\n",
    "\n",
    "                ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "                iterloss_ssim_list.append(ssim_data)\n",
    "\n",
    "            iter_ssim.append(np.mean(iterloss_ssim_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd556c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# expected: 4 R values x 1 photon level x 8 iteration counts = 32 entries\n",
    "len(iter_ssim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea680377",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MS-SSIM of the deep k-learning framework used in the runtime comparison.\n",
    "# NOTE(review): ssim_mean4 holds the non-generative *pretrained* results; confirm\n",
    "# this (rather than the not-pretrained ssim_mean3) is the intended baseline.\n",
    "best_deep_k_learning = ssim_mean4\n",
    "# the time in ms for the deep learning model to process one object, per R\n",
    "deep_k_learning_time = [4.45, 4.42, 4.43, 4.42]\n",
    "# add the time to generate the approximant object (presumably the first saved\n",
    "# runtime entry - TODO confirm); the sum is the total deep k-learning runtime\n",
    "for i, R in enumerate([0.25, 0.5, 1, 2]):\n",
    "    runtime_array = np.load(data_folder_prefix + '/Simulated/R_Sweep/runtimes-R-%0.2f-phperpix-%d-lr-0.50.npy' % (R, 1e4))\n",
    "    deep_k_learning_time[i] += runtime_array[1, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29aa44ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# total per-object runtime of deep k-learning (model time + approximant time), one value per R\n",
    "deep_k_learning_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8696bc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime vs. MS-SSIM: iterative reconstruction curves against deep k-learning.\n",
    "# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in matplotlib 3.6 and the\n",
    "# old name was removed in 3.8; use whichever name this installation provides.\n",
    "whitegrid_style = 'seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid'\n",
    "plt.style.use(whitegrid_style)\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['g', 'b', 'chocolate', 'darkorange'])\n",
    "\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "annotations = ['R = 0.25', 'R = 0.5', 'R = 1', 'R = 2']\n",
    "\n",
    "marker = ['x', '*', '.', '>']\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "lower_bound = []\n",
    "upper_bound = []\n",
    "\n",
    "for i, R in enumerate([0.25, 0.5, 1, 2]):\n",
    "    runtime_array = np.load(data_folder_prefix + '/Simulated/R_Sweep/runtimes-R-%0.2f-phperpix-%d-lr-0.50.npy' % (R, 1e4))\n",
    "    # iter_ssim holds 8 iteration counts per R; the slice 8*i+1 .. 8*i+7 skips the\n",
    "    # single-iteration entry. idx ends as (last j whose MS-SSIM reaches the deep\n",
    "    # k-learning value) - 1, or 7 if none does, and the curve is cut there.\n",
    "    # NOTE(review): without the commented-out `break` this uses the *last* crossing,\n",
    "    # not the first, and idx can become -1 (empty curve) - confirm intent.\n",
    "    idx = 7\n",
    "    for j, quality in enumerate(iter_ssim[8*i + 1: 8*i + 8]):\n",
    "        if quality >= best_deep_k_learning[i]:\n",
    "            idx = j - 1\n",
    "#             print(\"lower bound ratio:\", runtime_array[1, idx+1]/deep_k_learning_time[i])\n",
    "#             lower_bound.append(runtime_array[1, idx+1]/deep_k_learning_time[i])\n",
    "#             print(\"upper bound ratio:\", runtime_array[1, idx+2]/deep_k_learning_time[i])\n",
    "#             upper_bound.append(runtime_array[1, idx+2]/deep_k_learning_time[i])\n",
    "#             break\n",
    "    ax.plot(runtime_array[1, 1:idx+1], iter_ssim[8*i + 1: 8*i+idx+1], '--' + marker[i],\n",
    "            label = \"Iterative \" + annotations[i], markersize=29)\n",
    "\n",
    "plt.scatter(deep_k_learning_time, best_deep_k_learning, marker = 'o', color='r', s=380, label = \"Deep k-Learning\")\n",
    "\n",
    "for i, label in enumerate(annotations):\n",
    "    plt.annotate(label, (deep_k_learning_time[i], best_deep_k_learning[i]), fontsize=35, color='r')\n",
    "\n",
    "plt.legend(loc='lower right', fontsize=36)\n",
    "plt.xlabel('Runtime in millisecond', fontsize=35)\n",
    "plt.ylabel('MS-SSIM', fontsize=35)\n",
    "plt.ylim(0, 1)\n",
    "#plt.xlim(0, 100)\n",
    "plt.xscale('log')\n",
    "#plt.title('Runtime Comparison', fontsize=35)\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "839a3397",
   "metadata": {},
   "outputs": [],
   "source": [
    "# iterative MS-SSIM in R-major order: 8 iteration counts (1, 5, ..., 60) for each R\n",
    "print(iter_ssim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8461f6e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MS-SSIM metric used for every comparison below\n",
    "ssim = tf.image.ssim_multiscale\n",
    "# switch the data prefix to the fixed-R noise sweep results\n",
    "# NOTE(review): test_images is not reloaded here, so the noise-sweep cells compare\n",
    "# against the R-sweep ground truth - confirm both sweeps share the same test set.\n",
    "prefix = data_folder_prefix + '/Simulated/Fixed_R_Noise_Sweep/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "409f6413",
   "metadata": {},
   "source": [
    "# Repeat for the noise sweep results. Each cell below computes the MS-SSIM values for different photon levels using a different framework"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1808ba6d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Generative, not-pretrained results for the noise sweep at fixed R = 0.5: for\n",
    "# every photon level, sweep alpha and keep the MS-SSIM mean/std of the best alpha.\n",
    "pcc_mean11 = []  # NOTE(review): never filled in this cell\n",
    "ssim_mean11 = []\n",
    "ssim_std11 = []\n",
    "\n",
    "R = 0.5\n",
    "alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]\n",
    "# each (alpha, R, photon level) result is expected in exactly one of these dated files\n",
    "date_tags = ['2021-09-27', '2021-09-28', '2021-09-29', '2021-09-30']\n",
    "\n",
    "for photon_level in [1, 10, 100.0, 1e3]:\n",
    "    # reconstructions for every alpha, kept so the best one can be plotted\n",
    "    pic = []\n",
    "    # per-alpha MS-SSIM mean and std over the test objects at this photon level\n",
    "    ssim_temp = []\n",
    "    std_temp = []\n",
    "\n",
    "    for alpha in alphas:\n",
    "        suffix = '-not-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)\n",
    "        found = [tag + suffix for tag in date_tags if os.path.isfile(tag + suffix)]\n",
    "        # stop the alpha sweep at the first missing or ambiguous result\n",
    "        if not found:\n",
    "            print('missing files for R-%0.2f-alpha-%0.2f-photon-%0.2f' % (R, alpha, photon_level))\n",
    "            break\n",
    "        if len(found) > 1:\n",
    "            print('multiple files for R-%0.2f-alpha-%0.2f-photon-%0.2f: %s' % (R, alpha, photon_level, found))\n",
    "            break\n",
    "        rec_test_output = loadmat(found[0])['rec_test_output']\n",
    "        pic.append(rec_test_output)\n",
    "\n",
    "        # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "        loss_ssim_list = []\n",
    "        for gen, true in zip(rec_test_output, test_images):\n",
    "            true = tf.expand_dims(true, -1)\n",
    "            gen = tf.expand_dims(gen, -1)\n",
    "            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "            loss_ssim_list.append(ssim_data)\n",
    "\n",
    "        ssim_temp.append(np.mean(loss_ssim_list))\n",
    "        std_temp.append(np.std(loss_ssim_list))\n",
    "\n",
    "    # np.argmax would otherwise fail with an opaque error on an empty list\n",
    "    if not ssim_temp:\n",
    "        raise FileNotFoundError('no generative not-pretrained results found for photon_level = %s' % photon_level)\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    idx = np.argmax(ssim_temp)\n",
    "    print('the best alpha is:', alphas[idx])\n",
    "    ssim_mean11.append(ssim_temp[idx])\n",
    "    ssim_std11.append(std_temp[idx])\n",
    "\n",
    "    # plot the first 16 reconstructions for the best alpha\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(16):\n",
    "        ax = fig.add_subplot(4, 4, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(pic[idx][i], cmap='viridis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a37150",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Generative, pretrained results for the noise sweep at fixed R = 0.5: for every\n",
    "# photon level, sweep alpha and keep the MS-SSIM mean/std of the best alpha.\n",
    "ssim_mean22 = []\n",
    "ssim_std22 = []\n",
    "\n",
    "R = 0.5\n",
    "alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]\n",
    "# each (alpha, R, photon level) result is expected in exactly one of these dated files\n",
    "date_tags = ['2021-09-27', '2021-09-28', '2021-09-29', '2021-09-30']\n",
    "\n",
    "for photon_level in [1, 10, 100.0, 1e3]:\n",
    "    # reconstructions for every alpha, kept so the best one can be plotted\n",
    "    pic = []\n",
    "    # per-alpha MS-SSIM mean and std over the test objects at this photon level\n",
    "    ssim_temp = []\n",
    "    std_temp = []\n",
    "\n",
    "    for alpha in alphas:\n",
    "        # load the generative *pretrained* network results\n",
    "        suffix = '-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)\n",
    "        found = [tag + suffix for tag in date_tags if os.path.isfile(tag + suffix)]\n",
    "        # stop the alpha sweep at the first missing or ambiguous result\n",
    "        if not found:\n",
    "            print('missing files for R-%0.2f-alpha-%0.2f-photon-%0.2f' % (R, alpha, photon_level))\n",
    "            break\n",
    "        if len(found) > 1:\n",
    "            print('multiple files for R-%0.2f-alpha-%0.2f-photon-%0.2f: %s' % (R, alpha, photon_level, found))\n",
    "            break\n",
    "        rec_test_output = loadmat(found[0])['rec_test_output']\n",
    "        pic.append(rec_test_output)\n",
    "\n",
    "        # MS-SSIM object by object, using the ground-truth dynamic range as max_val\n",
    "        loss_ssim_list = []\n",
    "        for gen, true in zip(rec_test_output, test_images):\n",
    "            true = tf.expand_dims(true, -1)\n",
    "            gen = tf.expand_dims(gen, -1)\n",
    "            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "            loss_ssim_list.append(ssim_data)\n",
    "\n",
    "        ssim_temp.append(np.mean(loss_ssim_list))\n",
    "        std_temp.append(np.std(loss_ssim_list))\n",
    "\n",
    "    # np.argmax would otherwise fail with an opaque error on an empty list\n",
    "    if not ssim_temp:\n",
    "        raise FileNotFoundError('no generative pretrained results found for photon_level = %s' % photon_level)\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    idx = np.argmax(ssim_temp)\n",
    "    print('the best alpha is:', alphas[idx])\n",
    "    ssim_mean22.append(ssim_temp[idx])\n",
    "    ssim_std22.append(std_temp[idx])\n",
    "\n",
    "    # plot the first 16 reconstructions for the best alpha\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "    for i in range(16):\n",
    "        ax = fig.add_subplot(4, 4, i + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(pic[idx][i], cmap='viridis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cad80e7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Non-generative, not-pretrained results for the noise sweep at fixed R = 0.5\n",
    "# (one run per photon level, so there is no alpha search).\n",
    "ssim_mean33 = []\n",
    "ssim_std33 = []\n",
    "\n",
    "R = 0.5\n",
    "\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    result_file = '2021-09-29not-pretrained-R-' + str(R) + '-peak-' + str(photon_level) + '.mat'\n",
    "    matfile = loadmat(result_file)\n",
    "    rec_test_output = matfile['rec_test_output']\n",
    "\n",
    "    # show the first 16 reconstructions on a 4x4 grid\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(16):\n",
    "        ax = fig.add_subplot(4, 4, panel + 1)\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)\n",
    "        plt.imshow(rec_test_output[panel], cmap='viridis')\n",
    "\n",
    "    # per-object MS-SSIM, with the ground-truth dynamic range as max_val\n",
    "    per_object_ssim = [\n",
    "        ssim(tf.expand_dims(true, -1), tf.expand_dims(gen, -1),\n",
    "             tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "        for gen, true in zip(rec_test_output, test_images)\n",
    "    ]\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    ssim_mean33.append(np.mean(per_object_ssim))\n",
    "    ssim_std33.append(np.std(per_object_ssim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa537174",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Non-generative model, pretrained (R = 0.5): for each photon level show 16\n",
    "# reconstructions and record the mean / std SSIM against the ground truth.\n",
    "ssim_mean44 = []\n",
    "ssim_std44 = []\n",
    "\n",
    "R = 0.5\n",
    "\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    rec_test_output = loadmat('2021-09-29pretrained-R-0.5-peak-' + str(photon_level) + '.mat')['rec_test_output']\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(16):\n",
    "        ax = fig.add_subplot(4, 4, panel + 1)\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)\n",
    "        ax.imshow(rec_test_output[panel], cmap='viridis')\n",
    "\n",
    "    # third ssim argument: each ground-truth image's value range (max - min)\n",
    "    scores = [ssim(tf.expand_dims(true, -1), tf.expand_dims(gen, -1),\n",
    "                   tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "              for gen, true in zip(rec_test_output, test_images)]\n",
    "    mean_score, std_score = np.mean(scores), np.std(scores)\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    ssim_mean44.append(mean_score)\n",
    "    ssim_std44.append(std_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07b3dd64",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# End-to-End model (R = 0.5): for each photon level show 16 reconstructions and\n",
    "# record the mean / std SSIM against the ground truth.\n",
    "ssim_mean55 = []\n",
    "ssim_std55 = []\n",
    "\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    rec_test_output = np.load('End-to-End-test-output-R-0.5-photon-' + str(photon_level) + '.npy')\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(16):\n",
    "        ax = fig.add_subplot(4, 4, panel + 1)\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)\n",
    "        ax.imshow(rec_test_output[panel], cmap='viridis')\n",
    "\n",
    "    # third ssim argument: each ground-truth image's value range (max - min)\n",
    "    scores = [ssim(tf.expand_dims(true, -1), tf.expand_dims(gen, -1),\n",
    "                   tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "              for gen, true in zip(rec_test_output, test_images)]\n",
    "    mean_score, std_score = np.mean(scores), np.std(scores)\n",
    "\n",
    "    print('photon_level =', photon_level)\n",
    "    ssim_mean55.append(mean_score)\n",
    "    ssim_std55.append(std_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c228ba8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Input approximants (1 iteration, lr = 1, R = 0.5): plot 16 per photon level and report\n",
    "# their raw mean SSIM (no find_global_phase alignment or clipping is applied here).\n",
    "lr = 1\n",
    "R = 0.5\n",
    "iters = 1\n",
    "\n",
    "approximant_ssim = []\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    iter_input = np.load(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy'\n",
    "                         % (R, photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "    iterloss_ssim_list = []\n",
    "    for gen, true in zip(iter_input, test_images):\n",
    "        true = tf.expand_dims(true, -1)\n",
    "        gen = tf.expand_dims(gen, -1)\n",
    "        ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))\n",
    "        iterloss_ssim_list.append(ssim_data)\n",
    "    # These SSIM values used to be computed and then discarded; keep and report them.\n",
    "    approximant_ssim.append(np.mean(iterloss_ssim_list))\n",
    "    print('photon_level =', photon_level, 'approximant SSIM =', approximant_ssim[-1])\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    # separate panel index so the plotting loop does not shadow any outer loop variable\n",
    "    for panel in range(16):\n",
    "        ax = fig.add_subplot(4, 4, panel + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(iter_input[panel], cmap='viridis')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb490c55",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Iterative reconstruction after 1000 iterations (lr = 0.5, R = 0.5). Before scoring, each\n",
    "# result goes through find_global_phase(gen, true) and is clipped to the ground-truth range;\n",
    "# the figures show the reconstructions as loaded.\n",
    "iterative_std1 = []\n",
    "iterative_ssim1 = []\n",
    "\n",
    "lr = 0.5\n",
    "R = 0.5\n",
    "iters = 1000\n",
    "\n",
    "for photon_level in [1, 10, 100, 1e3]:\n",
    "    iter_input = np.load(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy'\n",
    "                         % (R, photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "    scores = []\n",
    "    for gen, true in zip(iter_input, test_images):\n",
    "        aligned = np.clip(find_global_phase(gen, true), np.min(true), np.max(true))\n",
    "        true_t = tf.expand_dims(true, -1)\n",
    "        scores.append(ssim(true_t, tf.expand_dims(aligned, -1),\n",
    "                           tf.math.reduce_max(true_t) - tf.math.reduce_min(true_t)))\n",
    "\n",
    "    plt.rcParams['figure.figsize'] = [20, 20]\n",
    "    fig = plt.figure()\n",
    "    fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "    for panel in range(16):\n",
    "        ax = fig.add_subplot(4, 4, panel + 1)\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)\n",
    "        ax.imshow(iter_input[panel], cmap='viridis')\n",
    "\n",
    "    iterative_ssim1.append(np.mean(scores))\n",
    "    iterative_std1.append(np.std(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0693b257",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-SSIM lists: ssim_mean11 / ssim_mean22 (generative), ssim_mean33 (non-generative,\n",
    "# not pretrained) and ssim_mean44 (non-generative, pretrained), printed side by side.\n",
    "print(*(ssim_mean11, ssim_mean22, ssim_mean33, ssim_mean44))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2528faed",
   "metadata": {},
   "source": [
    "# Generating noise sweep plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85c1d8bb",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# SSIM vs. photon count at R = 0.5 for each framework (bars: mean, error bars: std).\n",
    "# matplotlib 3.6 renamed 'seaborn-whitegrid' to 'seaborn-v0_8-whitegrid' and 3.8 removed the old name.\n",
    "plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid')\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'chocolate', 'olive'])\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "labels = ['10$^0$ photons', '10$^1$ photons', '10$^2$ photons', '10$^3$ photons']\n",
    "# Slice each series to the plotted levels so this cell still runs after the\n",
    "# 10^4-photon results have been appended to the same lists further down.\n",
    "n_levels = len(labels)\n",
    "\n",
    "x = np.arange(n_levels)  # the label locations\n",
    "width = 0.13  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(x - width * 3/2, ssim_mean44[:n_levels], width, yerr=ssim_std44[:n_levels], label='Non-Generative', ecolor='black', capsize=5)\n",
    "ax.bar(x - width * 1/2, ssim_mean11[:n_levels], width, yerr=ssim_std11[:n_levels], label='Generative', ecolor='black', capsize=5)\n",
    "ax.bar(x + width * 1/2, iterative_ssim1[:n_levels], width, yerr=iterative_std1[:n_levels], label='Iterative', ecolor='black', capsize=5)\n",
    "ax.bar(x + width * 3/2, ssim_mean55[:n_levels], width, yerr=ssim_std55[:n_levels], label='End-to-End', ecolor='black', capsize=5)\n",
    "\n",
    "ax.set_ylabel('MS-SSIM', fontsize=35)\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels, fontsize=25)\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "ax.legend(fontsize=33)\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.ylim([0, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fcbe94e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same SSIM bar chart at R = 0.5 with the photon levels in descending order (10^3 first).\n",
    "# matplotlib 3.6 renamed 'seaborn-whitegrid' to 'seaborn-v0_8-whitegrid' and 3.8 removed the old name.\n",
    "plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid')\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'chocolate', 'olive'])\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "labels = np.flip(['10$^0$ photons', '10$^1$ photons', '10$^2$ photons', '10$^3$ photons'])\n",
    "# Slice each series to the four plotted levels BEFORE flipping, so bars and labels stay\n",
    "# paired even after the 10^4-photon results have been appended further down.\n",
    "n_levels = len(labels)\n",
    "\n",
    "x = np.arange(n_levels)  # the label locations\n",
    "width = 0.13  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(x - width * 3/2, np.flip(ssim_mean44[:n_levels]), width, yerr=np.flip(ssim_std44[:n_levels]), label='Non-Generative', ecolor='black', capsize=5)\n",
    "ax.bar(x - width * 1/2, np.flip(ssim_mean11[:n_levels]), width, yerr=np.flip(ssim_std11[:n_levels]), label='Generative', ecolor='black', capsize=5)\n",
    "ax.bar(x + width * 1/2, np.flip(iterative_ssim1[:n_levels]), width, yerr=np.flip(iterative_std1[:n_levels]), label='Iterative', ecolor='black', capsize=5)\n",
    "ax.bar(x + width * 3/2, np.flip(ssim_mean55[:n_levels]), width, yerr=np.flip(ssim_std55[:n_levels]), label='End-to-End', ecolor='black', capsize=5)\n",
    "\n",
    "ax.set_ylabel('MS-SSIM', fontsize=35)\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels, fontsize=25)\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "ax.legend(fontsize=39)\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.ylim([0, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d70b9b7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extend the R = 0.5 noise-sweep series with the 10^4-photon result taken from the R-sweep lists\n",
    "# (assumes element 1 of each R-sweep list is R = 0.5 -- TODO confirm against the R-sweep cells).\n",
    "# Slice-assigning past the first four entries (photon levels 1..10^3) instead of appending keeps\n",
    "# the cell idempotent: re-running it replaces the 10^4 entry rather than adding a duplicate.\n",
    "n_base = 4\n",
    "for noise_sweep, r_sweep in [(ssim_mean44, ssim_mean4), (ssim_std44, ssim_std4),\n",
    "                             (ssim_mean33, ssim_mean3), (ssim_std33, ssim_std3),\n",
    "                             (ssim_mean22, ssim_mean2), (ssim_std22, ssim_std2),\n",
    "                             (iterative_ssim1, iterative_ssim), (iterative_std1, iterative_std),\n",
    "                             (ssim_mean55, ssim_mean5), (ssim_std55, ssim_std5)]:\n",
    "    noise_sweep[n_base:] = [r_sweep[1]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87237ad3",
   "metadata": {},
   "source": [
    "# R=0.5, noise sweep from 1 to 10^4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e357b89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SSIM vs. photon count at R = 0.5 from 1 to 10^4 photons (bars: mean, error bars: std).\n",
    "# matplotlib 3.6 renamed 'seaborn-whitegrid' to 'seaborn-v0_8-whitegrid' and 3.8 removed the old name.\n",
    "plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid')\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'chocolate', 'olive'])\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "labels = ['10$^0$ photons', '10$^1$ photons', '10$^2$ photons', '10$^3$ photons', '10$^4$ photons']\n",
    "# one bar per label; slicing guards against any extra entries in the series\n",
    "n_levels = len(labels)\n",
    "\n",
    "x = np.arange(n_levels)  # the label locations\n",
    "width = 0.13  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(x - width * 3/2, ssim_mean44[:n_levels], width, yerr=ssim_std44[:n_levels], label='Non-Generative', ecolor='black', capsize=5)\n",
    "ax.bar(x - width * 1/2, ssim_mean22[:n_levels], width, yerr=ssim_std22[:n_levels], label='Generative', ecolor='black', capsize=5)\n",
    "ax.bar(x + width * 1/2, iterative_ssim1[:n_levels], width, yerr=iterative_std1[:n_levels], label='Iterative', ecolor='black', capsize=5)\n",
    "ax.bar(x + width * 3/2, ssim_mean55[:n_levels], width, yerr=ssim_std55[:n_levels], label='End-to-End', ecolor='black', capsize=5)\n",
    "\n",
    "ax.set_ylabel('MS-SSIM', fontsize=35)\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(labels, fontsize=25)\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "ax.legend(fontsize=38)\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.ylim([0, 1.05])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28a6974b",
   "metadata": {},
   "source": [
    "# MS-SSIM at each iteration count for the iterative reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0818693",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SSIM of the iterative reconstruction for each iteration count (lr = 0.5, R = 0.5).\n",
    "# iter_ssim is flat: photon levels in np.flip([1, 10, 100, 1e3]) order, i.e. [1e3, 100, 10, 1],\n",
    "# and within each level one mean SSIM per entry of iteration_counts (9 per level).\n",
    "lr = 0.5\n",
    "iteration_counts = [1, 5, 10, 50, 100, 200, 300, 500, 1000]\n",
    "\n",
    "iter_ssim = []\n",
    "\n",
    "for R in [0.5]:\n",
    "    for photon_level in np.flip([1, 10, 100, 1e3]):\n",
    "        for iters in iteration_counts:\n",
    "            iter_input = np.load(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy'\n",
    "                                 % (R, photon_level, iters, lr)).astype(np.float32)\n",
    "\n",
    "            scores = []\n",
    "            for gen, true in zip(iter_input, test_images):\n",
    "                # apply find_global_phase, then clip to the ground-truth value range\n",
    "                aligned = np.clip(find_global_phase(gen, true), np.min(true), np.max(true))\n",
    "                true_t = tf.expand_dims(true, -1)\n",
    "                scores.append(ssim(true_t, tf.expand_dims(aligned, -1),\n",
    "                                   tf.math.reduce_max(true_t) - tf.math.reduce_min(true_t)))\n",
    "\n",
    "            iter_ssim.append(np.mean(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33122a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deep k-learning points for the runtime plot. The SSIM values come from the non-generative,\n",
    "# pretrained model (ssim_mean44); only its first four entries (photon levels 1..10^3) are used so\n",
    "# they pair with the four runtimes below even after the 10^4-photon result was appended to it.\n",
    "# Slicing also copies, so later changes to ssim_mean44 do not leak into this series.\n",
    "best_deep_k_learning = ssim_mean44[:4]\n",
    "# the time in (ms) for the deep learning model to process one object, per photon level\n",
    "deep_k_learning_time = [4.45, 4.42, 4.43, 4.42]\n",
    "# add the time to generate the approximant object; the sum is the total runtime of the deep k-learning framework\n",
    "for i, photon_level in enumerate([1, 10, 100, 1e3]):\n",
    "    runtime_array = np.load(prefix + 'runtimes-R-0.50-phperpix-%0.2f-lr-0.50.npy' % photon_level)\n",
    "    deep_k_learning_time[i] += runtime_array[1, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d23b4096",
   "metadata": {},
   "source": [
    "# Generating plots for runtime comparison at low photon counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8510338",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime vs. SSIM at R = 0.5: iterative reconstruction (one curve per photon level,\n",
    "# one point per iteration count) against the deep k-learning points.\n",
    "# matplotlib 3.6 renamed 'seaborn-whitegrid' to 'seaborn-v0_8-whitegrid' and 3.8 removed the old name.\n",
    "plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid')\n",
    "\n",
    "mpl.rcParams['axes.prop_cycle'] = cycler(color=['g', 'b', 'chocolate', 'darkorange'])\n",
    "\n",
    "plt.rcParams['figure.figsize'] = [20, 16]\n",
    "\n",
    "photon_levels = [1, 10, 100, 1e3]\n",
    "annotations = ['10$^0$ photons', '10$^1$ photons', '10$^2$ photons', '10$^3$ photons']\n",
    "marker = ['x', '*', '.', '>']\n",
    "n_iter_counts = 9  # iteration counts per photon level stored in iter_ssim (1, 5, ..., 1000)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for i, photon_level in enumerate(photon_levels):\n",
    "    # iter_ssim was filled while looping over np.flip([1, 10, 100, 1e3]) (highest level first),\n",
    "    # so photon_levels[i] lives in block (3 - i). np.flip cannot reverse an enumerate object,\n",
    "    # so the block index has to be mapped explicitly to keep runtime file, label and SSIM paired.\n",
    "    start = n_iter_counts * (len(photon_levels) - 1 - i)\n",
    "    runtime_array = np.load(prefix + 'runtimes-R-0.50-phperpix-%0.2f-lr-0.50.npy' % photon_level)\n",
    "    # skip the first iteration count, assuming runtime_array[1, k] is the runtime of the k-th count\n",
    "    ax.plot(runtime_array[1, 1:], iter_ssim[start + 1:start + n_iter_counts], '--' + marker[i],\n",
    "            label='Iterative ' + annotations[i], markersize=29)\n",
    "\n",
    "plt.scatter(deep_k_learning_time, best_deep_k_learning, marker='o', color='r', s=380, label='Deep k-Learning')\n",
    "\n",
    "for i, label in enumerate(annotations):\n",
    "    if i == 3:\n",
    "        plt.annotate(label, (deep_k_learning_time[i], best_deep_k_learning[i]), xytext=(5.5, 0.895), fontsize=35, color='r')\n",
    "    else:\n",
    "        plt.annotate(label, (deep_k_learning_time[i], best_deep_k_learning[i]), fontsize=35, color='r')\n",
    "\n",
    "plt.legend(loc='lower right', fontsize=35)\n",
    "plt.xlabel('Runtime in millisecond', fontsize=35)\n",
    "plt.ylabel('MS-SSIM', fontsize=35)\n",
    "plt.ylim(0, 1.)\n",
    "plt.xlim(2, 300)\n",
    "plt.xscale('log')\n",
    "ax.patch.set_edgecolor('black')\n",
    "ax.patch.set_linewidth(3)\n",
    "ax.tick_params(axis='both', labelsize=35)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63487b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point `prefix` at the fixed-R noise-sweep folder read by the visual comparison below.\n",
    "# NOTE: this rebinds the global `prefix`; cells above re-run afterwards will read from this folder.\n",
    "prefix = f'{data_folder_prefix}/Simulated/Fixed_R_Noise_Sweep/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a9452ee",
   "metadata": {},
   "source": [
    "# Generating plots for visual comparison at low photon counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21018202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual comparison at R = 0.5: one row per photon level (10^3 down to 1). Columns: 1-iteration\n",
    "# approximant, 100-iteration reconstruction (both mean-matched to the ground truth), End-to-End,\n",
    "# non-generative (pretrained), generative (not pretrained) and the ground truth, for test object idx.\n",
    "plt.rcParams['figure.figsize'] = [20+4, 20]\n",
    "fig = plt.figure()\n",
    "fig.subplots_adjust(hspace=.05, wspace=.05)\n",
    "\n",
    "R = 0.5\n",
    "idx = 1  # test object shown in every panel\n",
    "photon_levels = [1e3, 100, 10, 1]\n",
    "alphas = np.flip([0.25, 0.25, 0.5, 0.25])  # generative-model alpha for each photon level\n",
    "n_cols = 6\n",
    "# NOTE(review): 5 grid rows for 4 photon levels leaves the bottom row empty -- kept as in the original layout\n",
    "n_rows = 5\n",
    "\n",
    "# Ground truth for this sweep, loaded once. It gets its own name so the R-sweep\n",
    "# `test_images` used by the SSIM cells above is not silently replaced.\n",
    "test_images_noise = np.angle(np.load(prefix + 'test_images-R-0.50.npy'))\n",
    "mean_true = np.mean(test_images_noise, axis=(1, 2))\n",
    "vmin, vmax = np.min(test_images_noise), np.max(test_images_noise)\n",
    "\n",
    "for row, (photon_level, alpha) in enumerate(zip(photon_levels, alphas)):\n",
    "    # Load each photon level's reconstructions once instead of once per panel.\n",
    "    prox_1 = np.load(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy'\n",
    "                     % (R, photon_level, 1, 1)).astype(np.float32)\n",
    "    prox_100 = np.load(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy'\n",
    "                       % (R, photon_level, 100, 0.5)).astype(np.float32)\n",
    "    # per-object offsets that shift each iterative result's mean onto the ground truth's mean\n",
    "    b1 = mean_true - np.mean(prox_1, axis=(1, 2))\n",
    "    b100 = mean_true - np.mean(prox_100, axis=(1, 2))\n",
    "\n",
    "    e2e = np.load('End-to-End-test-output-R-0.5-photon-' + str(photon_level) + '.npy')\n",
    "    not_gen = loadmat('2021-09-29pretrained-R-0.5-peak-' + str(photon_level) + '.mat')['rec_test_output']\n",
    "    gen_not_pre = loadmat('2021-09-28-not-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level))['rec_test_output']\n",
    "\n",
    "    panels = [prox_1[idx] + b1[idx], prox_100[idx] + b100[idx], e2e[idx],\n",
    "              not_gen[idx], gen_not_pre[idx], test_images_noise[idx]]\n",
    "    for col, panel in enumerate(panels):\n",
    "        ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)\n",
    "        ax.axes.xaxis.set_visible(False)\n",
    "        ax.axes.yaxis.set_visible(False)\n",
    "        plt.imshow(np.rot90(panel, 3), cmap='viridis')\n",
    "        plt.clim(vmin, vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1c5f4f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
